Conversation
…kernel negative index handling Add a use_pool_indexing constexpr to both the small-batch and big-batch pretranspose decode kernels, enabling zero-copy state access directly from the pool via h0_indices and eliminating gather/scatter overhead. Also handle negative pool indices (padding slots) inside the kernel: blocks with negative indices skip computation and write zeros to the output, removing the need for a host-side torch.where remap (~37µs/call saved). Combined effect: K-last decode is 4-5.6% faster than V-last at BS>=4.
… decode launchers
…ingle function Consolidate gated_delta_rule_decode_pretranspose_pooled into gated_delta_rule_decode_pretranspose by adding an optional state_indices parameter. When state_indices is provided, the kernel uses pool-indexed (zero-copy) mode; otherwise it uses direct 1:1 batch-to-state mapping. This eliminates ~175 lines of duplicated Python wrapper code while the underlying CUDA kernels remain unchanged. The compiled kernel cache key now includes pool_size and use_pool_indexing to ensure correct cache separation between the two modes.
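The mode selection and cache-key separation described above can be sketched as follows. `select_decode_mode` and its return layout are illustrative names, not the actual FlashInfer internals.

```python
def select_decode_mode(batch_size, state_shape, state_indices=None):
    """Pick pooled vs. direct 1:1 mapping and build a cache key that keeps
    the two compiled-kernel variants separate (illustrative sketch)."""
    use_pool_indexing = state_indices is not None
    # In pooled mode the state's leading dim is the pool size, which can
    # differ from the batch size; in direct mode they coincide.
    pool_size = state_shape[0] if use_pool_indexing else batch_size
    cache_key = ("gdn_decode_pretranspose", pool_size, use_pool_indexing)
    return use_pool_indexing, pool_size, cache_key
```

With `state_indices=[3, 7, 12, 25]` and a 128-slot pool, this yields `use_pool_indexing=True`, `pool_size=128`, and a cache key distinct from the direct-mapping variant, so the two compilations never collide.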
When using pool indexing (state_indices), a non-contiguous state tensor could silently produce incorrect results because the kernel assumes contiguous memory layout for pointer arithmetic. Add an explicit assertion to catch this early.
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough

Adds pooled (indirect) state access to gated-delta-rule decode.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant API as "gated_delta_rule_decode_pretranspose"
    participant Cache as "KernelCache/_get_compiled_decode_kernel"
    participant CUDA as "CUDA Kernel"
    participant Pool as "State Pool / h0_source"
    Client->>API: call gated_delta_rule_decode_pretranspose(state, state_indices?, ...)
    API->>Cache: request compiled kernel (pool_size, use_pool_indexing, ...)
    Cache-->>API: compiled kernel handle
    API->>CUDA: launch kernel with grid shaped by grid_batch (B*HV) and use_pool_indexing
    CUDA->>Pool: read state via state_indices (indirect) or direct mapping
    CUDA-->>API: write outputs and updated state slots (respecting negative-index padding)
    API-->>Client: return outputs and updated state buffer
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 | ❌ 1

❌ Failed checks (1 inconclusive)
✅ Passed checks (4 passed)
Summary of Changes

Hello @xutizhou, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a significant optimization to the Gated Delta Rule (GDN) decode kernels by implementing a pooled indexing mechanism. This enhancement allows the kernels to directly access and update a shared state pool using indirect indices, enabling a zero-copy mode that bypasses intermediate data transfers. The changes involve modifying kernel interfaces, adapting internal state management, and updating the kernel compilation and launch logic to support this more efficient state-handling strategy.

Highlights
Code Review
This pull request introduces pooled decoding for the Gated Delta Rule, a significant performance optimization for inference that avoids state reallocation. The changes are well-implemented, adding a use_pool_indexing flag and state_indices to map batch items to a state pool. The kernel logic, launcher functions, and public API have been updated accordingly, including necessary input validation and caching mechanism adjustments. I've identified one minor opportunity for code simplification to reduce duplication. Overall, this is a solid and valuable contribution.
flashinfer/gdn_decode.py (outdated)

```python
    o[(i_n, i_t, i_hv, tidx)] = sOutput[tidx]
else:
    # Padding slot: write zeros to output
    start_v_tiles = batch_inner * num_v_tiles_per_block
```
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/gdn_decode.py`:
- Around line 1042-1052: Add a bounds check for state_indices when
use_pool_indexing is true: after verifying contiguity, shape and dtype, assert
that all values are >= 0 and < pool_size (e.g., check state_indices.min() >= 0
and state_indices.max() < pool_size or use torch.any to detect OOB) to prevent
GPU OOB accesses; gate this check behind a debug flag (or a cheap runtime
condition like torch.is_grad_enabled() or an explicit debug parameter) so it
only runs in debug builds.
🧹 Nitpick comments (1)
flashinfer/gdn_decode.py (1)
237-240: Non-English comment in kernel code.
`# V 方向分 tiles` should be translated to English (e.g., `# Tile along V dimension`) for codebase consistency and accessibility. Same issue at line 526.
flashinfer/gdn_decode.py (outdated)

```python
if use_pool_indexing:
    assert state.is_contiguous(), (
        "state must be contiguous when using pool indexing (state_indices); "
        "a non-contiguous tensor may silently produce incorrect results"
    )
    assert state_indices.shape == (B,), (
        f"Expected state_indices shape [{B}], got {state_indices.shape}"
    )
    assert state_indices.dtype == torch.int32, (
        f"state_indices must be int32, got {state_indices.dtype}"
    )
```
Missing bounds validation: state_indices values are not checked against pool_size.
The assertion block validates shape and dtype of state_indices, but doesn't verify that non-negative values are < pool_size. An out-of-range index would silently cause an OOB global memory access in the kernel, potentially corrupting memory or causing a GPU fault.
Consider adding a debug-mode bounds check:
🛡️ Proposed fix

```diff
     assert state_indices.dtype == torch.int32, (
         f"state_indices must be int32, got {state_indices.dtype}"
     )
+    # Validate index bounds (non-negative indices must be < pool_size)
+    valid_mask = state_indices >= 0
+    if valid_mask.any():
+        max_idx = state_indices[valid_mask].max().item()
+        assert max_idx < pool_size, (
+            f"state_indices contains index {max_idx} >= pool_size={pool_size}"
+        )
```

This adds a small overhead, so you may want to gate it behind a debug flag or `torch.is_grad_enabled()` check.
🤖 Prompt for AI Agents
In `@flashinfer/gdn_decode.py` around lines 1042 - 1052, Add a bounds check for
state_indices when use_pool_indexing is true: after verifying contiguity, shape
and dtype, assert that all values are >= 0 and < pool_size (e.g., check
state_indices.min() >= 0 and state_indices.max() < pool_size or use torch.any to
detect OOB) to prevent GPU OOB accesses; gate this check behind a debug flag (or
a cheap runtime condition like torch.is_grad_enabled() or an explicit debug
parameter) so it only runs in debug builds.
@xutizhou can you explain what this PR is about? e.g. adding a description of what the pool is in GDN.

updated in the description.
…lashinfer-ai#2521 Revert gdn_decode.py to base — the state_indices parameter and pool validation in gated_delta_rule_decode_pretranspose belong to PR flashinfer-ai#2521. This PR now only contains BF16 CuTe DSL kernel pool indexing changes in gdn_decode_bf16_state.py. AI-assisted (Claude)
Add comprehensive tests for gated_delta_rule_decode_pretranspose with pool indexing (state_indices parameter): - Test 1: Pooled decode with negative indices (~20% padding) - Test 2: sglang forward_decode calling pattern (unique indices + PAD_SLOT_ID) - Test 3: Pooled vs non-pooled equivalence with identity mapping - Test 4: All-padding batch (output zeros, pool state unchanged) All tests verify output and state against per-sample reference implementation. AI-assisted.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/gdn/test_decode_pooled.py (1)
40-41: Generalize the SM gate to avoid skipping future architectures.

Line 40 uses a fixed allowlist (`[9, 10, 11, 12]`), which will start skipping newer SM majors even when they should be valid. A lower-bound check is safer.

Suggested patch

```diff
-    if cc[0] not in [9, 10, 11, 12]:
-        pytest.skip(f"GDN decode requires SM90+ or SM100+, but got SM{cc[0]}{cc[1]}")
+    if cc[0] < 9:
+        pytest.skip(f"GDN decode requires SM90+, but got SM{cc[0]}{cc[1]}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gdn/test_decode_pooled.py` around lines 40 - 41, The test currently hardcodes an allowlist of SM majors using cc[0] not in [9, 10, 11, 12], causing future SM majors to be skipped; change this to a lower-bound check (e.g., if cc[0] < 9) so any SM major >=9 is accepted, and update the pytest.skip message to reflect the minimum required SM (use cc[0] to report actual SM and a message like "requires SM9+"). Keep the check and message near the same spot where cc and pytest.skip are used in tests/gdn/test_decode_pooled.py.
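The lower-bound gate this nitpick suggests reduces to a one-line predicate; `should_skip` is a hypothetical helper name, not part of the test file:

```python
def should_skip(cc_major: int, min_sm_major: int = 9) -> bool:
    """Skip only architectures older than SM90; future majors pass
    automatically, unlike a fixed allowlist such as [9, 10, 11, 12]."""
    return cc_major < min_sm_major
```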
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/gdn/test_decode_pooled.py`:
- Around line 229-230: The test simulates a sentinel-slot layout but samples
valid slot indices from [0, pool_size-1], allowing 0 to be treated as a real
request; update the sampling to exclude the sentinel by drawing cache indices
from the range 1..pool_size inclusive (instead of 0..pool_size-1) so slot 0
remains the sentinel. Change the sampling logic that builds cache_indices (and
any related uses around the block referencing pool_size and PAD_SLOT_ID) to use
the corrected range (e.g., start at 1 and go to pool_size) in both places
flagged (around the code that constructs cache_indices and the subsequent
samples).
---
Nitpick comments:
In `@tests/gdn/test_decode_pooled.py`:
- Around line 40-41: The test currently hardcodes an allowlist of SM majors
using cc[0] not in [9, 10, 11, 12], causing future SM majors to be skipped;
change this to a lower-bound check (e.g., if cc[0] < 9) so any SM major >=9 is
accepted, and update the pytest.skip message to reflect the minimum required SM
(use cc[0] to report actual SM and a message like "requires SM9+"). Keep the
check and message near the same spot where cc and pytest.skip are used in
tests/gdn/test_decode_pooled.py.
tests/gdn/test_decode_pooled.py (outdated)

```python
# - Full pool passed as state (pool_size+1 slots, slot 0 is sentinel)
# - cache_indices from scheduler, with PAD_SLOT_ID = -1 for padding
```
Sentinel-slot simulation is inconsistent with the stated SGLang layout.
Line 229 says slot 0 is sentinel, but Line 258 samples valid indices from [0, pool_size-1], so 0 can be used as a real request slot. This weakens the “exactly like SGLang” guarantee and can hide off-by-one mapping bugs.
Suggested patch

```diff
-    cache_indices_int64 = torch.randperm(pool_size, device=device)[:num_valid].to(
+    cache_indices_int64 = (torch.randperm(pool_size, device=device)[:num_valid] + 1).to(
         torch.int64
     )
```

Also applies to: 257-260
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/gdn/test_decode_pooled.py` around lines 229 - 230, The test simulates a
sentinel-slot layout but samples valid slot indices from [0, pool_size-1],
allowing 0 to be treated as a real request; update the sampling to exclude the
sentinel by drawing cache indices from the range 1..pool_size inclusive (instead
of 0..pool_size-1) so slot 0 remains the sentinel. Change the sampling logic
that builds cache_indices (and any related uses around the block referencing
pool_size and PAD_SLOT_ID) to use the corrected range (e.g., start at 1 and go
to pool_size) in both places flagged (around the code that constructs
cache_indices and the subsequent samples).
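The corrected sampling can be sketched without torch; names mirror the test (`pool_size`, `PAD_SLOT_ID`) but this is an illustrative stand-in, not the test code:

```python
import random

def sample_cache_indices(pool_size, num_valid, pad_to, pad_slot_id=-1, seed=0):
    """Sample unique slot indices from 1..pool_size inclusive, so slot 0
    stays reserved as the sentinel; pad the batch with PAD_SLOT_ID."""
    rng = random.Random(seed)
    valid = rng.sample(range(1, pool_size + 1), num_valid)
    return valid + [pad_slot_id] * (pad_to - num_valid)
```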
…ode-pooled

# Conflicts:
#   flashinfer/gdn_decode.py
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/gdn_decode.py (1)
1067-1103: ⚠️ Potential issue | 🔴 Critical: bf16 fast-path ignores pooled indexing semantics.

When `state_indices` is provided and the bf16 fast-path conditions match, the code still enters `_gated_delta_rule_gdn_decode_klast_bf16_state` (lines 1067-1073), but that call has no index mapping and therefore bypasses pooled indirection.

🔧 Proposed fix

```diff
 use_gdn_decode_klast_bf16_state = (
     _GDN_DECODE_KLAST_BF16_STATE_AVAILABLE
+    and not use_pool_indexing
     and state.dtype == torch.bfloat16
     and T in (1, 2, 3, 4)
     and K == 128
     and V == 128
 )
+if use_pool_indexing and state.dtype == torch.bfloat16:
+    raise NotImplementedError(
+        "state_indices with bfloat16 state is not supported in this decode path yet."
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 1067 - 1103, The bf16 fast-path (_gated_delta_rule_gdn_decode_klast_bf16_state) currently ignores pooled indexing and thus bypasses state indirection when state_indices is set; fix by detecting state_indices before taking the bf16 fast-path and either (a) disable the bf16 fast-path (i.e., set use_gdn_decode_klast_bf16_state to False) so the standard path that respects pooled indexing runs, or (b) materialize a gathered initial_state_source = state.index_select(0, state_indices) (or equivalent gather) and pass that gathered tensor as initial_state_source into _gated_delta_rule_gdn_decode_klast_bf16_state so pooled indexing is honored; ensure you reference state_indices, state, initial_state_source, and _gated_delta_rule_gdn_decode_klast_bf16_state when making the change and preserve the existing output handling logic.
♻️ Duplicate comments (1)
flashinfer/gdn_decode.py (1)
1053-1064: ⚠️ Potential issue | 🔴 Critical: validate `state_indices` upper bound before kernel launch.

`state_indices` shape/dtype are checked, but non-negative values are not constrained to `< pool_size`. This can trigger OOB GPU memory access in pooled mode (line 1062 onward).

🛡️ Proposed fix

```diff
 if use_pool_indexing:
     assert state.is_contiguous(), (
         "state must be contiguous when using pool indexing (state_indices); "
         "a non-contiguous tensor may silently produce incorrect results"
     )
     assert state_indices.shape == (B,), (
         f"Expected state_indices shape [{B}], got {state_indices.shape}"
     )
     assert state_indices.dtype == torch.int32, (
         f"state_indices must be int32, got {state_indices.dtype}"
     )
+    valid_mask = state_indices >= 0
+    if torch.any(valid_mask):
+        max_idx = int(state_indices[valid_mask].max().item())
+        assert max_idx < pool_size, (
+            f"state_indices contains index {max_idx} >= pool_size={pool_size}"
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 1053 - 1064, The code validates state_indices shape/dtype but misses verifying bounds, which can cause OOB GPU access when using pooled mode; in the pooled-path (where use_pool_indexing is True and before the kernel launch that uses state_indices) add a validation that all values in state_indices are >= 0 and < pool_size (and keep the existing dtype/int checks), e.g. check torch.all(state_indices >= 0) and torch.all(state_indices < pool_size) (or equivalent CPU-side/min/max check) and raise/assert with a clear message referencing state_indices and pool_size so invalid indices are caught before the kernel is launched.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@flashinfer/gdn_decode.py`:
- Around line 1067-1103: The bf16 fast-path
(_gated_delta_rule_gdn_decode_klast_bf16_state) currently ignores pooled
indexing and thus bypasses state indirection when state_indices is set; fix by
detecting state_indices before taking the bf16 fast-path and either (a) disable
the bf16 fast-path (i.e., set use_gdn_decode_klast_bf16_state to False) so the
standard path that respects pooled indexing runs, or (b) materialize a gathered
initial_state_source = state.index_select(0, state_indices) (or equivalent
gather) and pass that gathered tensor as initial_state_source into
_gated_delta_rule_gdn_decode_klast_bf16_state so pooled indexing is honored;
ensure you reference state_indices, state, initial_state_source, and
_gated_delta_rule_gdn_decode_klast_bf16_state when making the change and
preserve the existing output handling logic.
---
Duplicate comments:
In `@flashinfer/gdn_decode.py`:
- Around line 1053-1064: The code validates state_indices shape/dtype but misses
verifying bounds, which can cause OOB GPU access when using pooled mode; in the
pooled-path (where use_pool_indexing is True and before the kernel launch that
uses state_indices) add a validation that all values in state_indices are >= 0
and < pool_size (and keep the existing dtype/int checks), e.g. check
torch.all(state_indices >= 0) and torch.all(state_indices < pool_size) (or
equivalent CPU-side/min/max check) and raise/assert with a clear message
referencing state_indices and pool_size so invalid indices are caught before the
kernel is launched.
flashinfer/gdn_decode.py (outdated)

```python
# Partition for load
thr_copy_load = tiled_copy_load.get_slice(tidx)
# Tile along V dimension
```
/bot run
…isted) Replace from_dlpack(h0_source) with make_fake_compact_tensor using cute.sym_int() for the pool_batch dimension, so a single compiled kernel handles any pool_size at runtime. stride_order=(2,1,0) ensures row-major layout with concrete strides for cp.async alignment. Benchmarks show zero performance regression vs compile-time shape: from_dlpack: 0.0306ms median (bs=32, pool=128) sym_int: 0.0307ms median (bs=32, pool=128)
5f15fa8 to 308ad6b
[FAILED] Pipeline #45304647: 6/20 passed

/bot run

@xutizhou is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
…om PR flashinfer-ai#2619 Resolve merge conflicts with upstream main which added pool+indices support via the bf16 fast path (PR flashinfer-ai#2619). Key changes: - Adopt upstream API naming: state_indices -> initial_state_indices, state pool passed via initial_state param - Update test_decode_pooled.py to use new API with bf16 state - Skip negative-index tests (bf16 kernel does not support them yet) - Legacy f32 CuTe DSL path preserved for non-pool usage AI-assisted merge resolution.
Remove 'assert not use_pool' from f32 path — the sym_int approach already handles arbitrary pool_size at runtime with zero overhead. Tests 1/2/4 use f32 state with negative indices (padding support). Test 3 uses bf16 state (routed to bf16 fast path). All 23 pooled decode tests pass. AI-assisted.
/bot run

[SUCCESS] Pipeline #45437134: 10/20 passed
29a2a7f to 01f091d
Merge test_decode_pooled.py into test_decode_delta_rule.py with: - state_dtype parametrize (bf16 + f32) for pool test - negative indices and all-padding tests (f32 state only) - per-sample Python reference to avoid JIT cache contamination - float32 dt_bias matching SGLang production usage - pytestmark skip preserved to match upstream main CI
01f091d to e5df67c
## 📌 Description
This PR adds pool-indexed (indirect) state access to the GDN decode
kernel, enabling zero-copy integration with SGLang's state pool
architecture.
### Background: SGLang's State Pool Architecture
In SGLang, when serving linear attention models (like Qwen3-Next using
Gated Delta Rule), we maintain a **state pool** to store recurrent
states for all active requests:
`ssm_states: [num_layers, pool_size, num_heads, head_dim, head_dim]`
where `pool_size` = `max_num_reqs` (maximum concurrent requests).
Each active request has a `req_pool_idx` that maps it to a slot in this
pool. The mapping is **not contiguous** - requests come and go, so
indices can be scattered (e.g., a batch of 4 requests might have pool
indices `[3, 7, 12, 25]`).
### Motivation
The current GDN decode kernel expects state with shape `[B, H, K, V]`
where B equals batch size and there's a 1:1 mapping (batch index i →
state index i). To use it with SGLang's pool, we would need to:
1. **Gather** states from pool indices before kernel call
2. Run kernel on contiguous `[B, H, K, V]` state
3. **Scatter** updated states back to pool indices
This adds 2 extra memory copy operations per decode step.
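The baseline flow being eliminated looks like this (a NumPy sketch in which `step_fn` stands in for the CUDA decode kernel):

```python
import numpy as np

def decode_with_gather_scatter(pool, indices, step_fn):
    """Baseline: two extra memory copies per decode step."""
    gathered = pool[indices]            # copy 1: gather pool slots -> contiguous [B, ...]
    out, new_state = step_fn(gathered)  # kernel runs on the contiguous state
    pool[indices] = new_state           # copy 2: scatter updated states back
    return out
```

The pooled-indexing path removes both copies by letting the kernel dereference `pool` through the indices directly.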
### Changes
This PR adds a `state_indices` parameter for **zero-copy pool access**:
```python
def gated_delta_rule_decode_pretranspose(
q, k, v, beta,
state, # Can be [pool_size, H, K, V] instead of [B, H, K, V]
state_indices, # NEW: int32 tensor [B] mapping batch_idx -> pool_idx
...
)
```
When `state_indices` is provided:
- Kernel uses indirect addressing: `state[state_indices[batch_idx]]`
instead of `state[batch_idx]`
- Negative indices (padding slots for CUDA graph) skip computation and
write zeros to output
- Eliminates gather/scatter overhead + host-side `torch.where` for
padding (~37μs/call)
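The zero-copy semantics above can be expressed as a per-batch reference (NumPy sketch; `step_fn` stands in for the per-slot recurrence computed by the kernel):

```python
import numpy as np

def pooled_decode_reference(pool, state_indices, step_fn, out_shape):
    """Reference for the kernel semantics: indirect state access via
    state_indices, in-place pool update, zeros for negative indices."""
    out = np.zeros((len(state_indices), *out_shape), dtype=pool.dtype)
    for b, idx in enumerate(state_indices):
        if idx < 0:
            continue  # padding slot: skip compute, output stays zero
        out[b], pool[idx] = step_fn(pool[idx])  # read and update pool slot directly
    return out
```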
## 🔍 Related Issues
- [sgl-project/sglang#18361](sgl-project/sglang#18361): FlashInfer K-last GDN integration into SGLang
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
## Reviewer Notes
This PR is required for integrating FlashInfer's K-last GDN kernels into
SGLang. The pool indexing feature allows SGLang to directly use its
state pool without gather/scatter overhead.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Optional pooled (indirect) state access via a new state_indices
parameter enabling zero-copy pooled state handling; negative indices
yield zeroed outputs.
* **Improvements**
* Kernel launch paths, grid sizing, and compiled-kernel caching
differentiate pooled vs. non-pooled modes; APIs and docstrings updated
to propagate pooling flags while maintaining compatibility.
* **Tests**
* New test suite validating pooled decode correctness,
padding/negative-index behavior, state updates, and pooled vs.
non-pooled equivalence.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
```python
# Build h0_source: [pool_size*HV, V, K] for kernel
if use_pool:
    pool_size = initial_state.shape[0]
    assert initial_state.is_contiguous(), (
```
Do we consider to support non-contiguous state?
in which situation do we need non-contiguous state?
vLLM uses non-contiguous state
> vLLM uses non-contiguous state
For non-contiguous states, we should be able to compute the true indices using strides. Once the assert is removed, it can also work with our kernel.
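The stride-based index computation mentioned here can be sketched in plain Python (the real change would feed the tensor's actual strides into the CUDA kernel's pointer arithmetic instead of assuming contiguity):

```python
def element_offset(index, strides):
    """Flat element offset of a multi-dim index under arbitrary strides.
    For a contiguous [pool, H, K, V] tensor the strides are
    (H*K*V, K*V, V, 1); a non-contiguous view just supplies different ones."""
    return sum(i * s for i, s in zip(index, strides))
```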
## 📌 Description
This PR adds pool-indexed (indirect) state access to the GDN decode
kernel, enabling zero-copy integration with SGLang's state pool
architecture.
### Background: SGLang's State Pool Architecture
In SGLang, when serving linear attention models (like Qwen3-Next using
Gated Delta Rule), we maintain a **state pool** to store recurrent
states for all active requests:
`ssm_states: [num_layers, pool_size, num_heads, head_dim, head_dim]`
where `pool_size` = `max_num_reqs` (maximum concurrent requests).
Each active request has a `req_pool_idx` that maps it to a slot in this
pool. The mapping is **not contiguous** - requests come and go, so
indices can be scattered (e.g., a batch of 4 requests might have pool
indices `[3, 7, 12, 25]`).
### Motivation
The current GDN decode kernel expects state with shape `[B, H, K, V]`
where B equals batch size and there's a 1:1 mapping (batch index i →
state index i). To use it with SGLang's pool, we would need to:
1. **Gather** states from pool indices before kernel call
2. Run kernel on contiguous `[B, H, K, V]` state
3. **Scatter** updated states back to pool indices
This adds 2 extra memory copy operations per decode step.
### Changes
This PR adds a `state_indices` parameter for **zero-copy pool access**:
```python
def gated_delta_rule_decode_pretranspose(
q, k, v, beta,
state, # Can be [pool_size, H, K, V] instead of [B, H, K, V]
state_indices, # NEW: int32 tensor [B] mapping batch_idx -> pool_idx
...
)
```
When `state_indices` is provided:
- Kernel uses indirect addressing: `state[state_indices[batch_idx]]`
instead of `state[batch_idx]`
- Negative indices (padding slots for CUDA graph) skip computation and
write zeros to output
- Eliminates gather/scatter overhead + host-side `torch.where` for
padding (~37μs/call)
## 🔍 Related Issues
-
[sgl-project/sglang#18361](sgl-project/sglang#18361)
- FlashInfer K-last GDN integration into SGLang
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
## Reviewer Notes
This PR is required for integrating FlashInfer's K-last GDN kernels into
SGLang. The pool indexing feature allows SGLang to directly use its
state pool without gather/scatter overhead.
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).
## Reviewer Notes
<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Optional pooled (indirect) state access via a new state_indices
parameter enabling zero-copy pooled state handling; negative indices
yield zeroed outputs.
* **Improvements**
* Kernel launch paths, grid sizing, and compiled-kernel caching
differentiate pooled vs. non-pooled modes; APIs and docstrings updated
to propagate pooling flags while maintaining compatibility.
* **Tests**
* New test suite validating pooled decode correctness,
padding/negative-index behavior, state updates, and pooled vs.
non-pooled equivalence.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
## 📌 Description
This PR adds pool-indexed (indirect) state access to the GDN decode
kernel, enabling zero-copy integration with SGLang's state pool
architecture.
### Background: SGLang's State Pool Architecture
In SGLang, when serving linear attention models (like Qwen3-Next using
Gated Delta Rule), we maintain a **state pool** to store recurrent
states for all active requests:
`ssm_states: [num_layers, pool_size, num_heads, head_dim, head_dim]`
where `pool_size` = `max_num_reqs` (maximum concurrent requests).
Each active request has a `req_pool_idx` that maps it to a slot in this
pool. The mapping is **not contiguous** - requests come and go, so
indices can be scattered (e.g., a batch of 4 requests might have pool
indices `[3, 7, 12, 25]`).
### Motivation
The current GDN decode kernel expects state with shape `[B, H, K, V]`
where B equals batch size and there's a 1:1 mapping (batch index i →
state index i). To use it with SGLang's pool, we would need to:
1. **Gather** states from pool indices before kernel call
2. Run kernel on contiguous `[B, H, K, V]` state
3. **Scatter** updated states back to pool indices
This adds 2 extra memory copy operations per decode step.
### Changes
This PR adds a `state_indices` parameter for **zero-copy pool access**:
```python
def gated_delta_rule_decode_pretranspose(
q, k, v, beta,
state, # Can be [pool_size, H, K, V] instead of [B, H, K, V]
state_indices, # NEW: int32 tensor [B] mapping batch_idx -> pool_idx
...
)
```
When `state_indices` is provided:
- Kernel uses indirect addressing: `state[state_indices[batch_idx]]`
instead of `state[batch_idx]`
- Negative indices (padding slots for CUDA graph) skip computation and
write zeros to output
- Eliminates gather/scatter overhead + host-side `torch.where` for
padding (~37μs/call)
## 🔍 Related Issues
-
[sgl-project/sglang#18361](sgl-project/sglang#18361)
- FlashInfer K-last GDN integration into SGLang
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
## Reviewer Notes
This PR is required for integrating FlashInfer's K-last GDN kernels into
SGLang. The pool indexing feature allows SGLang to directly use its
state pool without gather/scatter overhead.
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
## 📌 Description
vLLM uses a non-contiguous state for GDN; make FlashInfer support it as well.
## 🔍 Related Issues
#2521 #2687
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>